This example shows how to use recurrent neural networks (with and without attention) to classify documents.
We use our usual sentiment analysis benchmark.
import torch
from torch import nn
import time
import torchtext
import random
from collections import defaultdict
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn')
Let's define an RNN-based text classifier. We'll apply a bidirectional RNN and base the classification on the final state in each direction.
We can optionally use pre-trained embeddings, which are assumed to be stored in the torchtext vocabulary object.
class RNNTextClassifier(nn.Module):

    def __init__(self, text_field, class_field, emb_dim, rnn_size, update_pretrained=False):
        super().__init__()

        voc_size = len(text_field.vocab)
        n_classes = len(class_field.vocab)

        # Embedding layer.
        self.embedding = nn.Embedding(voc_size, emb_dim)

        # If we're using pre-trained embeddings, copy them into the model's embedding layer.
        if text_field.vocab.vectors is not None:
            self.embedding.weight = torch.nn.Parameter(text_field.vocab.vectors,
                                                       requires_grad=update_pretrained)

        # The RNN module: either a basic RNN, an LSTM, or a GRU.
        # (Note: an LSTM returns its final state as a pair (h, c), while the code in
        # forward below assumes a single final-state tensor, as returned by RNN and GRU.)
        #self.rnn = nn.RNN(input_size=emb_dim, hidden_size=rnn_size,
        #                  bidirectional=True, num_layers=1)
        #self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=rnn_size,
        #                   bidirectional=True, num_layers=1)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=rnn_size,
                          bidirectional=True, num_layers=1)

        # And finally, a linear layer on top of the RNN layer to produce the output.
        self.top_layer = nn.Linear(2*rnn_size, n_classes)

    def forward(self, texts):
        # The words in the documents are encoded as integers. The shape of the documents
        # tensor is (max_len, n_docs), where n_docs is the number of documents in this batch,
        # and max_len is the maximal length of a document in the batch.

        # First look up the embeddings for all the words in the documents.
        # The shape is now (max_len, n_docs, emb_dim).
        embedded = self.embedding(texts)

        # The RNN returns two tensors: one representing the outputs at all positions
        # of the final layer, and another representing the final states of each layer.
        # In this example, we'll use just the final states.
        # NB: for a bidirectional RNN, the final state corresponds to the *last* token
        # in the forward direction and the *first* token in the backward direction.
        rnn_out, final_state = self.rnn(embedded)

        # The shape of final_state is (2*n_layers, n_docs, rnn_size), assuming that
        # the RNN is bidirectional.
        # We select the top layer's forward and backward states and concatenate them.
        top_forward = final_state[-2]
        top_backward = final_state[-1]
        top_both = torch.cat([top_forward, top_backward], dim=1)

        # Apply the linear layer and return the output.
        return self.top_layer(top_both)
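To make the shape bookkeeping in the comments concrete, here is a small, optional sanity check with made-up dimensions (it is not part of the classifier itself). It shows what a bidirectional GRU returns, and why the last two rows of the final state are the forward and backward states of the top layer.

# A quick shape check with invented sizes: emb_dim=8, rnn_size=5, 3 documents of 7 words.
check_rnn = nn.GRU(input_size=8, hidden_size=5, bidirectional=True, num_layers=1)
dummy_embedded = torch.randn(7, 3, 8)            # (max_len, n_docs, emb_dim)
check_out, check_state = check_rnn(dummy_embedded)
print(check_out.shape)                           # torch.Size([7, 3, 10]): 2*rnn_size per position
print(check_state.shape)                         # torch.Size([2, 3, 5]): forward and backward final states
print(torch.cat([check_state[-2], check_state[-1]], dim=1).shape)   # torch.Size([3, 10])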
We now add an attention model, which computes a weighted average of all the state vectors; the weights are based on an "importance" score computed by a neural network.
We first define the attention model, and then the text classifier that uses it. The attention model is described in detail in the comments, while the classification model should be fairly self-explanatory.
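Concretely, if h_1, …, h_n are the RNN states for one document, the model below computes a score e_i = w·h_i + b for each position (where w and b are the parameters of a linear layer), turns the scores into weights α_i = softmax(e)_i, and returns the weighted sum Σ_i α_i·h_i as the document representation.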
class SimpleAttention(nn.Module):

    def __init__(self, rnn_size):
        super().__init__()
        # This is the neural network that computes the attention scores.
        # To keep things simple, we'll use a linear model here.
        self.attn_nn = nn.Linear(rnn_size, 1)

    def forward(self, rnn_output):
        # The input to the attention model is the output from the top layer of the RNN,
        # which is a tensor containing the states for each position in each document.
        # The shape of this tensor is (n_words, n_docs, rnn_dim).

        # First, we apply the attention neural network to each state in the RNN output.
        e = self.attn_nn(rnn_output)

        # The shape is now (n_words, n_docs, 1). Squeezing the last dimension reshapes
        # the tensor to (n_words, n_docs). (We squeeze that dimension explicitly, so that
        # nothing unexpected happens when a batch or a document has size one.)
        e = e.squeeze(dim=2)

        # Compute attention weights by applying the softmax over the word positions (dim 0),
        # so that the weights for each document sum to one.
        # This tensor has the same shape as e.
        alpha = torch.softmax(e, dim=0)

        # We weigh each RNN state by its attention weight.
        # In order to carry out the element-wise multiplication, we need to "flip"
        # the tensor so that the RNN state dimension comes first.
        # This tensor has the shape (rnn_dim, n_words, n_docs).
        weighted = alpha * rnn_output.permute(2, 0, 1)

        # Compute a weighted sum of the RNN state vectors. We sum over the word dimension.
        # The shape is now (rnn_dim, n_docs).
        out = weighted.sum(dim=1)

        # "Flip" the tensor back to the shape (n_docs, rnn_dim) so that it fits
        # with the linear layer in the text classifier.
        return out.t()
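As a quick check that the reshaping steps fit together, we can apply the attention module to a random tensor with made-up sizes (7 words, 3 documents, a bidirectional state size of 10):

check_attn = SimpleAttention(10)
dummy_rnn_out = torch.randn(7, 3, 10)            # (n_words, n_docs, rnn_dim)
print(check_attn(dummy_rnn_out).shape)           # torch.Size([3, 10]): one summary vector per document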
class RNNAttentionTextClassifier(nn.Module):

    def __init__(self, text_field, class_field, emb_dim, rnn_size, update_pretrained=False):
        super().__init__()

        voc_size = len(text_field.vocab)
        n_classes = len(class_field.vocab)

        self.embedding = nn.Embedding(voc_size, emb_dim)
        if text_field.vocab.vectors is not None:
            self.embedding.weight = torch.nn.Parameter(text_field.vocab.vectors,
                                                       requires_grad=update_pretrained)

        #self.rnn = nn.RNN(input_size=emb_dim, hidden_size=rnn_size,
        #                  bidirectional=True, num_layers=1)
        #self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=rnn_size,
        #                   bidirectional=True, num_layers=1)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=rnn_size,
                          bidirectional=True, num_layers=1)

        self.attention = SimpleAttention(2*rnn_size)
        self.top_layer = nn.Linear(2*rnn_size, n_classes)

    def forward(self, texts):
        embedded = self.embedding(texts)
        rnn_out, final_state = self.rnn(embedded)

        # The attention model returns the weighted sum of the RNN states for each document.
        # The shape is (n_docs, 2*rnn_size).
        attention_out = self.attention(rnn_out)

        # Apply the linear layer and return the output.
        return self.top_layer(attention_out)
We train the classifier and evaluate on the validation set. This code is almost identical to the code that we saw in the CNN lecture.
For the first RNN-based classifier, the performance tends to be a bit lower than for the CNN from Lecture 2. When we add attention, the performance is usually slightly better than for the other models, peaking at about 0.86-0.87 on the validation set. However, the performance of both models is a bit "jumpy" and can vary between runs.
A note on pre-trained word embeddings. We are now using pre-trained embeddings: the GloVe model glove.6B.100d, which torchtext can download and load for us. The first time you run this code, the GloVe vectors will be downloaded, which takes some time; on subsequent runs, the cached copy is used. To use the pre-trained embeddings, they need to be copied into the neural network's parameters (see above).
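As a minimal sketch of that copying step, with a random matrix standing in for the GloVe vectors:

# Stand-in for a pre-trained matrix of shape (voc_size, emb_dim); in the real code this is
# text_field.vocab.vectors, filled by TEXT.build_vocab(..., vectors="glove.6B.100d").
pretrained_vectors = torch.randn(10, 100)
example_embedding = nn.Embedding(10, 100)
# Copy the vectors into the layer; requires_grad controls whether they are fine-tuned.
example_embedding.weight = nn.Parameter(pretrained_vectors, requires_grad=False)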
def read_data(corpus_file, datafields, label_column, doc_start):
    # Each line consists of a number of whitespace-separated metadata fields followed by the
    # document text: the label is in column label_column, and the text starts at column doc_start.
    with open(corpus_file, encoding='utf-8') as f:
        examples = []
        for line in f:
            columns = line.strip().split(maxsplit=doc_start)
            doc = columns[-1]
            label = columns[label_column]
            examples.append(torchtext.data.Example.fromlist([doc, label], datafields))
        return torchtext.data.Dataset(examples, datafields)
def evaluate_validation(scores, loss_function, gold):
    # Given the output scores, the loss function, and the gold-standard labels for a batch,
    # return the number of correct predictions and the loss value.
    guesses = scores.argmax(dim=1)
    n_correct = (guesses == gold).sum().item()
    return n_correct, loss_function(scores, gold).item()
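For example, with made-up scores for three documents and two classes:

dummy_scores = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.0, 4.0]])
dummy_gold = torch.tensor([0, 0, 1])
# The second document is misclassified, so we get 2 correct out of 3, plus the batch loss.
print(evaluate_validation(dummy_scores, torch.nn.CrossEntropyLoss(), dummy_gold))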
def main():
    TEXT = torchtext.data.Field(sequential=True, tokenize=lambda x: x.split())
    LABEL = torchtext.data.LabelField(is_target=True)
    datafields = [('text', TEXT), ('label', LABEL)]

    random.seed(0)

    data = read_data('data/all_sentiment_shuffled.txt', datafields, label_column=1, doc_start=3)
    train, valid = data.split([0.8, 0.2])

    use_pretrained = True
    if use_pretrained:
        print('We are using pre-trained word embeddings.')
        TEXT.build_vocab(train, vectors="glove.6B.100d")
    else:
        print('We are training word embeddings from scratch.')
        TEXT.build_vocab(train, max_size=10000)
    LABEL.build_vocab(train)

    # Declare the RNN classifier.
    #model = RNNTextClassifier(TEXT, LABEL, emb_dim=100, rnn_size=64, update_pretrained=True)
    model = RNNAttentionTextClassifier(TEXT, LABEL, emb_dim=100, rnn_size=64, update_pretrained=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    train_iterator = torchtext.data.BucketIterator(
        train,
        device=device,
        batch_size=128,
        sort_key=lambda x: len(x.text),
        repeat=False,
        train=True,
        sort=True)

    valid_iterator = torchtext.data.BucketIterator(
        valid,
        device=device,
        batch_size=128,
        sort_key=lambda x: len(x.text),
        repeat=False,
        train=False,
        sort=True)

    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0025, weight_decay=1e-4)

    train_batches = list(train_iterator)
    valid_batches = list(valid_iterator)

    history = defaultdict(list)

    for i in range(25):
        t0 = time.time()

        loss_sum = 0
        n_batches = 0

        model.train()
        for batch in train_batches:
            scores = model(batch.text)
            loss = loss_function(scores, batch.label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        train_loss = loss_sum / n_batches
        history['train_loss'].append(train_loss)

        n_correct = 0
        n_valid = len(valid)
        loss_sum = 0
        n_batches = 0

        model.eval()
        for batch in valid_batches:
            scores = model(batch.text)
            n_corr_batch, loss_batch = evaluate_validation(scores, loss_function, batch.label)
            loss_sum += loss_batch
            n_correct += n_corr_batch
            n_batches += 1

        val_acc = n_correct / n_valid
        val_loss = loss_sum / n_batches
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        t1 = time.time()
        print(f'Epoch {i+1}: train loss = {train_loss:.4f}, val loss = {val_loss:.4f}, val acc: {val_acc:.4f}, time = {t1-t0:.4f}')

    plt.plot(history['train_loss'])
    plt.plot(history['val_loss'])
    plt.plot(history['val_acc'])
    plt.legend(['training loss', 'validation loss', 'validation accuracy'])

main()